{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.MultiRocketPlus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MultiRocketPlus"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">MultiRocket: Multiple pooling operators and transformations for fast and effective time series classification.\n",
    "\n",
    "This is a Pytorch implementation of MultiRocket developed by Malcolm McLean and Ignacio Oguiza based on:\n",
    "\n",
    "Tan, C. W., Dempster, A., Bergmeir, C., & Webb, G. I. (2022). MultiRocket: multiple pooling operators and transformations for fast and effective time series classification. Data Mining and Knowledge Discovery, 36(5), 1623-1646.\n",
    "\n",
    "Original paper: https://link.springer.com/article/10.1007/s10618-022-00844-1\n",
    "\n",
    "Original repository:  https://github.com/ChangWeiTan/MultiRocket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import math\n",
    "from collections import OrderedDict\n",
    "import itertools\n",
    "from tsai.models.layers import rocket_nd_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Flatten(nn.Module):\n",
    "    def forward(self, x): return x.view(x.size(0), -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _LPVV(o, dim=2):\n",
    "    \"Longest stretch of positive values along a dimension(-1, 1)\"\n",
    "\n",
    "    seq_len = o.shape[dim]\n",
    "    binary_tensor = (o > 0).float()\n",
    "\n",
    "    diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),\n",
    "                      binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)\n",
    "\n",
    "    groups = (diff > 0).cumsum(dim)\n",
    "\n",
    "    # Ensure groups are within valid index bounds\n",
    "    groups = groups * binary_tensor.long()\n",
    "    valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))\n",
    "\n",
    "    counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)\n",
    "\n",
    "    longest_stretch = counts.max(dim)[0]\n",
    "\n",
    "    return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)\n",
    "\n",
    "def _MPV(o, dim=2):\n",
    "    \"Mean of Positive Values (any positive value)\"\n",
    "    o = torch.where(o > 0, o, torch.nan)\n",
    "    o = torch.nanmean(o, dim)\n",
    "    return torch.nan_to_num(o)\n",
    "\n",
    "def _RSPV(o, dim=2):\n",
    "    \"Relative Sum of Positive Values (-1, 1)\"\n",
    "    o_sum = torch.clamp_min(torch.abs(o).sum(dim), 1e-8)\n",
    "    o_pos_sum = torch.nansum(F.relu(o), dim)\n",
    "    return (o_pos_sum / o_sum) * 2 - 1\n",
    "\n",
    "def _MIPV(o, o_pos, dim=2):\n",
    "    \"Mean of Indices of Positive Values (-1, 1)\"\n",
    "    seq_len = o.shape[dim]\n",
    "    o_arange_shape = [1] * o_pos.ndim\n",
    "    o_arange_shape[dim] = -1\n",
    "    o_arange = torch.arange(o_pos.shape[dim], device=o.device).reshape(o_arange_shape)\n",
    "    o = torch.where(o_pos, o_arange, torch.nan)\n",
    "    o = torch.nanmean(o, dim)\n",
    "    return (torch.nan_to_num(o) / seq_len) * 2 - 1\n",
    "\n",
    "def _PPV(o_pos, dim=2):\n",
    "    \"Proportion of Positive Values (-1, 1)\"\n",
    "    return (o_pos).float().mean(dim) * 2 - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.imports import default_device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[ 0.5644, -0.0509, -0.0390,  0.4091],\n",
      "          [ 0.0517, -0.1471,  0.6458,  0.5593],\n",
      "          [ 0.4516, -0.0821,  0.1271,  0.0592],\n",
      "          [ 0.4151,  0.4376,  0.0763,  0.3780],\n",
      "          [ 0.2653, -0.1817,  0.0156,  0.4993]],\n",
      "\n",
      "         [[-0.0779,  0.0858,  0.1982,  0.3224],\n",
      "          [ 0.1130,  0.0714, -0.1779,  0.5360],\n",
      "          [-0.1848, -0.2270, -0.0925, -0.1217],\n",
      "          [ 0.2820, -0.0205, -0.2777,  0.3755],\n",
      "          [-0.2490,  0.2613,  0.4237,  0.4534]],\n",
      "\n",
      "         [[-0.0162,  0.6368,  0.0016,  0.1467],\n",
      "          [ 0.6035, -0.1365,  0.6930,  0.6943],\n",
      "          [ 0.2790,  0.3818, -0.0731,  0.0167],\n",
      "          [ 0.6442,  0.3443,  0.4829, -0.0944],\n",
      "          [ 0.2932,  0.6952,  0.5541,  0.5946]]],\n",
      "\n",
      "\n",
      "        [[[ 0.6757,  0.5740,  0.3071,  0.4400],\n",
      "          [-0.2344, -0.1056,  0.4773,  0.2432],\n",
      "          [ 0.2595, -0.1528, -0.0866,  0.6201],\n",
      "          [ 0.0657,  0.1220,  0.4849,  0.4254],\n",
      "          [ 0.3399, -0.1609,  0.3465,  0.2389]],\n",
      "\n",
      "         [[-0.0765,  0.0516,  0.0028,  0.4381],\n",
      "          [ 0.5212, -0.2781, -0.0896, -0.0301],\n",
      "          [ 0.6857,  0.3583,  0.5869,  0.3418],\n",
      "          [ 0.3002,  0.5135,  0.6011,  0.6499],\n",
      "          [-0.2807, -0.2888,  0.3965,  0.6585]],\n",
      "\n",
      "         [[-0.1368,  0.6677,  0.1439,  0.1434],\n",
      "          [-0.1820,  0.1041, -0.1211,  0.6103],\n",
      "          [ 0.5808,  0.4588,  0.4572,  0.3713],\n",
      "          [ 0.2389, -0.1392,  0.1371, -0.1570],\n",
      "          [ 0.2840,  0.1214, -0.0059,  0.5064]]]], device='mps:0')\n",
      "tensor([[[ 1.0000, -0.6000,  0.6000,  1.0000],\n",
      "         [-0.6000, -0.2000, -0.6000, -0.2000],\n",
      "         [ 0.6000,  0.2000, -0.2000,  0.2000]],\n",
      "\n",
      "        [[ 0.2000, -0.6000, -0.2000,  1.0000],\n",
      "         [ 0.2000, -0.2000,  0.2000,  0.2000],\n",
      "         [ 0.2000,  0.2000, -0.2000,  0.2000]]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "o = torch.rand(2, 3, 5, 4).to(default_device()) - .3\n",
    "print(o)\n",
    "\n",
    "output = _LPVV(o, dim=2)\n",
    "print(output)  # Should print: torch.Size([2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.3496, 0.4376, 0.2162, 0.3810],\n",
      "         [0.1975, 0.1395, 0.3109, 0.4218],\n",
      "         [0.4550, 0.5145, 0.4329, 0.3631]],\n",
      "\n",
      "        [[0.3352, 0.3480, 0.4040, 0.3935],\n",
      "         [0.5023, 0.3078, 0.3968, 0.5221],\n",
      "         [0.3679, 0.3380, 0.2460, 0.4079]]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "output = _MPV(o, dim=2)\n",
    "print(output)  # Should print: torch.Size([2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 1.0000, -0.0270,  0.9138,  1.0000],\n",
      "         [-0.1286,  0.2568,  0.0630,  0.8654],\n",
      "         [ 0.9823,  0.8756,  0.9190,  0.8779]],\n",
      "\n",
      "        [[ 0.7024,  0.2482,  0.8983,  1.0000],\n",
      "         [ 0.6168,  0.2392,  0.8931,  0.9715],\n",
      "         [ 0.5517,  0.8133,  0.7065,  0.8244]]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "output = _RSPV(o, dim=2)\n",
    "print(output)  # Should print: torch.Size([2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.3007, -1.0097, -0.6697, -0.2381],\n",
      "         [-1.0466, -0.9316, -0.9705, -0.3738],\n",
      "         [-0.2786, -0.2314, -0.3366, -0.4569]],\n",
      "\n",
      "        [[-0.5574, -0.8893, -0.3883, -0.2130],\n",
      "         [-0.5401, -0.8574, -0.4009, -0.1767],\n",
      "         [-0.6861, -0.5149, -0.7555, -0.4102]]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "output = _PPV(o, dim=2)\n",
    "print(output)  # Should print: torch.Size([2, 3, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MultiRocketFeaturesPlus(nn.Module):\n",
    "    fitting = False\n",
    "\n",
    "    def __init__(self, c_in, seq_len, num_features=10_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=9, max_num_kernels=84, diff=False):\n",
    "        super(MultiRocketFeaturesPlus, self).__init__()\n",
    "\n",
    "        self.c_in, self.seq_len = c_in, seq_len\n",
    "        self.kernel_size, self.max_num_channels = kernel_size, max_num_channels\n",
    "\n",
    "        # Kernels\n",
    "        indices, pos_values = self.get_indices(kernel_size, max_num_kernels)\n",
    "        self.num_kernels = len(indices)\n",
    "        kernels = (-torch.ones(self.num_kernels, 1, self.kernel_size)).scatter_(2, indices, pos_values)\n",
    "        self.indices = indices\n",
    "        self.kernels = nn.Parameter(kernels.repeat(c_in, 1, 1), requires_grad=False)\n",
    "        num_features = num_features // 4\n",
    "        self.num_features = num_features // self.num_kernels * self.num_kernels\n",
    "        self.max_dilations_per_kernel = max_dilations_per_kernel\n",
    "\n",
    "        # Dilations\n",
    "        self.set_dilations(seq_len)\n",
    "\n",
    "        # Channel combinations (multivariate)\n",
    "        if c_in > 1:\n",
    "            self.set_channel_combinations(c_in, max_num_channels)\n",
    "\n",
    "        # Bias\n",
    "        for i in range(self.num_dilations):\n",
    "            self.register_buffer(f'biases_{i}', torch.empty(\n",
    "                (self.num_kernels, self.num_features_per_dilation[i])))\n",
    "        self.register_buffer('prefit', torch.BoolTensor([False]))\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        _features = []\n",
    "        for i, (dilation, padding) in enumerate(zip(self.dilations, self.padding)):\n",
    "            _padding1 = i % 2\n",
    "\n",
    "            # Convolution\n",
    "            C = F.conv1d(x, self.kernels, padding=padding,\n",
    "                         dilation=dilation, groups=self.c_in)\n",
    "            if self.c_in > 1:  # multivariate\n",
    "                C = C.reshape(x.shape[0], self.c_in, self.num_kernels, -1)\n",
    "                channel_combination = getattr(\n",
    "                    self, f'channel_combinations_{i}')\n",
    "                C = torch.mul(C, channel_combination)\n",
    "                C = C.sum(1)\n",
    "\n",
    "            # Bias\n",
    "            if not self.prefit or self.fitting:\n",
    "                num_features_this_dilation = self.num_features_per_dilation[i]\n",
    "                bias_this_dilation = self.get_bias(\n",
    "                    C, num_features_this_dilation)\n",
    "                setattr(self, f'biases_{i}', bias_this_dilation)\n",
    "                if self.fitting:\n",
    "                    if i < self.num_dilations - 1:\n",
    "                        continue\n",
    "                    else:\n",
    "                        self.prefit = torch.BoolTensor([True])\n",
    "                        return\n",
    "                elif i == self.num_dilations - 1:\n",
    "                    self.prefit = torch.BoolTensor([True])\n",
    "            else:\n",
    "                bias_this_dilation = getattr(self, f'biases_{i}')\n",
    "\n",
    "            # Features\n",
    "            _features.append(self.apply_pooling_ops(\n",
    "                C[:, _padding1::2], bias_this_dilation[_padding1::2]))\n",
    "            _features.append(self.apply_pooling_ops(\n",
    "                C[:, 1-_padding1::2, padding:-padding], bias_this_dilation[1-_padding1::2]))\n",
    "\n",
    "        return torch.cat(_features, dim=1)\n",
    "\n",
    "    def fit(self, X, chunksize=None):\n",
    "        num_samples = X.shape[0]\n",
    "        if chunksize is None:\n",
    "            chunksize = min(num_samples, self.num_dilations * self.num_kernels)\n",
    "        else:\n",
    "            chunksize = min(num_samples, chunksize)\n",
    "        idxs = np.random.choice(num_samples, chunksize, False)\n",
    "        self.fitting = True\n",
    "        if isinstance(X, np.ndarray):\n",
    "            self(torch.from_numpy(X[idxs]).to(self.kernels.device))\n",
    "        else:\n",
    "            self(X[idxs].to(self.kernels.device))\n",
    "        self.fitting = False\n",
    "\n",
    "    def apply_pooling_ops(self, C, bias):\n",
    "        C = C.unsqueeze(-1)\n",
    "        bias = bias.view(1, bias.shape[0], 1, bias.shape[1])\n",
    "        pos_vals = (C > bias)\n",
    "        ppv = _PPV(pos_vals).flatten(1)\n",
    "        mpv = _MPV(C - bias).flatten(1)\n",
    "        # rspv = _RSPV(C - bias).flatten(1)\n",
    "        mipv = _MIPV(C, pos_vals).flatten(1)\n",
    "        lspv = _LPVV(pos_vals).flatten(1)\n",
    "        return torch.cat((ppv, mpv, mipv, lspv), dim=1)\n",
    "        return torch.cat((ppv, rspv, mipv, lspv), dim=1)\n",
    "\n",
    "    def set_dilations(self, input_length):\n",
    "        num_features_per_kernel = self.num_features // self.num_kernels\n",
    "        true_max_dilations_per_kernel = min(\n",
    "            num_features_per_kernel, self.max_dilations_per_kernel)\n",
    "        multiplier = num_features_per_kernel / true_max_dilations_per_kernel\n",
    "        max_exponent = np.log2((input_length - 1) / (self.kernel_size - 1))\n",
    "        dilations, num_features_per_dilation = \\\n",
    "            np.unique(np.logspace(0, max_exponent, true_max_dilations_per_kernel, base=2).astype(\n",
    "                np.int32), return_counts=True)\n",
    "        num_features_per_dilation = (\n",
    "            num_features_per_dilation * multiplier).astype(np.int32)\n",
    "        remainder = num_features_per_kernel - num_features_per_dilation.sum()\n",
    "        i = 0\n",
    "        while remainder > 0:\n",
    "            num_features_per_dilation[i] += 1\n",
    "            remainder -= 1\n",
    "            i = (i + 1) % len(num_features_per_dilation)\n",
    "        self.num_features_per_dilation = num_features_per_dilation\n",
    "        self.num_dilations = len(dilations)\n",
    "        self.dilations = dilations\n",
    "        self.padding = []\n",
    "        for i, dilation in enumerate(dilations):\n",
    "            self.padding.append((((self.kernel_size - 1) * dilation) // 2))\n",
    "\n",
    "    def set_channel_combinations(self, num_channels, max_num_channels):\n",
    "        num_combinations = self.num_kernels * self.num_dilations\n",
    "        if max_num_channels:\n",
    "            max_num_channels = min(num_channels, max_num_channels)\n",
    "        else:\n",
    "            max_num_channels = num_channels\n",
    "        max_exponent_channels = np.log2(max_num_channels + 1)\n",
    "        num_channels_per_combination = (\n",
    "            2 ** np.random.uniform(0, max_exponent_channels, num_combinations)).astype(np.int32)\n",
    "        self.num_channels_per_combination = num_channels_per_combination\n",
    "        channel_combinations = torch.zeros(\n",
    "            (1, num_channels, num_combinations, 1))\n",
    "        for i in range(num_combinations):\n",
    "            channel_combinations[:, np.random.choice(\n",
    "                num_channels, num_channels_per_combination[i], False), i] = 1\n",
    "        channel_combinations = torch.split(\n",
    "            channel_combinations, self.num_kernels, 2)  # split by dilation\n",
    "        for i, channel_combination in enumerate(channel_combinations):\n",
    "            self.register_buffer(\n",
    "                f'channel_combinations_{i}', channel_combination)  # per dilation\n",
    "\n",
    "    def get_quantiles(self, n):\n",
    "        return torch.tensor([(_ * ((np.sqrt(5) + 1) / 2)) % 1 for _ in range(1, n + 1)]).float()\n",
    "\n",
    "    def get_bias(self, C, num_features_this_dilation):\n",
    "        isp = torch.randint(C.shape[0], (self.num_kernels,))\n",
    "        samples = C[isp].diagonal().T\n",
    "        biases = torch.quantile(samples, self.get_quantiles(\n",
    "            num_features_this_dilation).to(C.device), dim=1).T\n",
    "        return biases\n",
    "\n",
    "    def get_indices(self, kernel_size, max_num_kernels):\n",
    "        num_pos_values = math.ceil(kernel_size / 3)\n",
    "        num_neg_values = kernel_size - num_pos_values\n",
    "        pos_values = num_neg_values / num_pos_values\n",
    "        if kernel_size > 9:\n",
    "            random_kernels = [np.sort(np.random.choice(kernel_size, num_pos_values, False)).reshape(\n",
    "                1, -1) for _ in range(max_num_kernels)]\n",
    "            indices = torch.from_numpy(\n",
    "                np.concatenate(random_kernels, 0)).unsqueeze(1)\n",
    "        else:\n",
    "            indices = torch.LongTensor(list(itertools.combinations(\n",
    "                np.arange(kernel_size), num_pos_values))).unsqueeze(1)\n",
    "            if max_num_kernels and len(indices) > max_num_kernels:\n",
    "                indices = indices[np.sort(np.random.choice(\n",
    "                    len(indices), max_num_kernels, False))]\n",
    "        return indices, pos_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MultiRocketBackbonePlus(nn.Module):\n",
    "    def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True):\n",
    "        super(MultiRocketBackbonePlus, self).__init__()\n",
    "\n",
    "        num_features_per_branch = num_features // (1 + use_diff)\n",
    "        self.branch_x = MultiRocketFeaturesPlus(c_in, seq_len, num_features=num_features_per_branch, max_dilations_per_kernel=max_dilations_per_kernel,\n",
    "                                                kernel_size=kernel_size, max_num_channels=max_num_channels, max_num_kernels=max_num_kernels)\n",
    "        if use_diff:\n",
    "            self.branch_x_diff = MultiRocketFeaturesPlus(c_in, seq_len - 1, num_features=num_features_per_branch, max_dilations_per_kernel=max_dilations_per_kernel,\n",
    "                                                         kernel_size=kernel_size, max_num_channels=max_num_channels, max_num_kernels=max_num_kernels)\n",
    "        if use_diff:\n",
    "            self.num_features = (self.branch_x.num_features + self.branch_x_diff.num_features) * 4 # 4 types of features\n",
    "        else:\n",
    "            self.num_features = self.branch_x.num_features * 4\n",
    "        self.use_diff = use_diff\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.use_diff:\n",
    "            x_features = self.branch_x(x)\n",
    "            x_diff_features = self.branch_x(torch.diff(x))\n",
    "            output = torch.cat([x_features, x_diff_features], dim=-1)\n",
    "            return output\n",
    "        else:\n",
    "            output = self.branch_x(x)\n",
    "            return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MultiRocketPlus(nn.Sequential):\n",
    "\n",
    "    def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,\n",
    "                 use_bn=True, fc_dropout=0, custom_head=None, zero_init=True, use_diff=True):\n",
    "\n",
    "        # Backbone\n",
    "        backbone = MultiRocketBackbonePlus(c_in, seq_len, num_features=num_features, max_dilations_per_kernel=max_dilations_per_kernel,\n",
    "                                          kernel_size=kernel_size, max_num_channels=max_num_channels, max_num_kernels=max_num_kernels, use_diff=use_diff)\n",
    "        num_features = backbone.num_features\n",
    "\n",
    "        # Head\n",
    "        self.head_nf = num_features\n",
    "        if custom_head is not None:\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(self.head_nf, c_out, 1)\n",
    "        elif d is not None:\n",
    "            head = rocket_nd_head(num_features, c_out, seq_len=None, d=d, use_bn=use_bn, fc_dropout=fc_dropout, zero_init=zero_init)\n",
    "        else:\n",
    "            layers = [Flatten()]\n",
    "            if use_bn:\n",
    "                layers += [nn.BatchNorm1d(num_features)]\n",
    "            if fc_dropout:\n",
    "                layers += [nn.Dropout(fc_dropout)]\n",
    "            linear = nn.Linear(num_features, c_out)\n",
    "            if zero_init:\n",
    "                nn.init.constant_(linear.weight.data, 0)\n",
    "                nn.init.constant_(linear.bias.data, 0)\n",
    "            layers += [linear]\n",
    "            head = nn.Sequential(*layers)\n",
    "\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "\n",
    "MultiRocket = MultiRocketPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.imports import default_device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(16, 5, 20).to(default_device())\n",
    "yb = torch.randint(0, 3, (16, 20)).to(default_device())\n",
    "\n",
    "model = MultiRocketPlus(5, 3, 20, d=None, use_diff=True).to(default_device())\n",
    "output = model(xb)\n",
    "assert output.shape == (16, 3)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(16, 5, 20).to(default_device())\n",
    "yb = torch.randint(0, 3, (16, 20)).to(default_device())\n",
    "\n",
    "model = MultiRocketPlus(5, 3, 20, d=None, use_diff=False).to(default_device())\n",
    "output = model(xb)\n",
    "assert output.shape == (16, 3)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 20, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(16, 5, 20).to(default_device())\n",
    "yb = torch.randint(0, 3, (16, 5, 20)).to(default_device())\n",
    "\n",
    "model = MultiRocketPlus(5, 3, 20, d=20, use_diff=True).to(default_device())\n",
    "output = model(xb)\n",
    "assert output.shape == (16, 20, 3)\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 10:53:13\n",
      "Correct notebook to script conversion! 😃\n",
      "Sunday 11/02/24 10:53:16 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
